Add model_to_minibatch transformation to convert all pm.Data to pm.Minibatch #7785
base: main
Conversation
You can use the lower level utility: pymc/pymc/variational/minibatch_rv.py (line 53 in ef26ae8).
Then make that a vanilla observed RV.
Ah, you already did that, so your question is how to get the total size? Grab the batch shape of the variable and constant fold it, without raising if it can't be fully folded.
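A minimal sketch of what these two suggestions could look like together; the variable names `full_data` and `obs_rv` are placeholders, not from the PR, and the exact call may differ from the final implementation:

```python
# Sketch: constant fold the leading batch dimension of the full data to get the total
# size, then wrap the observed RV with the lower-level utility mentioned above.
from pymc.pytensorf import constant_fold
from pymc.variational.minibatch_rv import create_minibatch_rv

# `full_data` is the original (un-sliced) data variable, `obs_rv` the observed RV built
# on the minibatched slice (both placeholders for this sketch).
(total_size,) = constant_fold([full_data.shape[0]], raise_not_constant=False)
new_obs_rv = create_minibatch_rv(obs_rv, total_size=total_size)
```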
My real issue was not understanding what needs to be the key and value in the replacements, between:
The best is usually to replace the whole fgraph.
I don't really understand what that answer means.
dprint the fgraph and it will perhaps be more obvious what I am mumbling.
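For illustration, a small runnable sketch of that advice (the toy model is made up; `fgraph_from_model` and its `inlined_views` flag are taken from the PR's draft code):

```python
import numpy as np
import pymc as pm
import pytensor
from pymc.model.fgraph import fgraph_from_model

with pm.Model() as model:
    data = pm.Data("data", np.zeros(10))
    pm.Normal("obs", mu=0.0, sigma=1.0, observed=data)

fgraph, memo = fgraph_from_model(model, inlined_views=True)
pytensor.dprint(fgraph)  # shows the dummy Ops PyMC uses to tag data and RVs in the fgraph

# In the PR's draft, `replacements` maps the cloned data variables (keys) to their
# pm.Minibatch slices (values); they would be applied on the whole fgraph, e.g.:
# fgraph.replace_all(tuple(replacements.items()), import_missing=True)
```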
The problem I was running into was that I ended up with two …
Because Minibatch assumes the data variables have the same length, it might make sense to take a variables argument. Or have some way to group data variables of the same size (same dim name maybe?)
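A toy sketch of that grouping concern (variable names made up): only data variables that share the same leading length can be sliced by a common pm.Minibatch call.

```python
import numpy as np
import pymc as pm

with pm.Model():
    x = pm.Data("x", np.random.normal(size=(1000, 3)))
    y = pm.Data("y", np.random.normal(size=1000))
    z = pm.Data("z", np.random.normal(size=50))   # different length: must go in its own group

    x_mb, y_mb = pm.Minibatch(x, y, batch_size=128)  # x and y are sliced with the same indices
```

A same-dim-name heuristic would group x and y (same leading dim) but leave z alone.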
Force-pushed from c1168de to 8d1b479.
minibatch_vars = Minibatch(*data_vars, batch_size=batch_size)
replacements = {datum: minibatch_vars[i] for i, datum in enumerate(data_vars)}
assert 0
# Add total_size to all observed RVs
Should only add to those that depend on the minibatch data no?
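One hedged way that check could be implemented (illustrative only; `observed_rvs` and `minibatch_vars` are placeholders for the corresponding variables in the transformed graph, not the PR's code):

```python
# Only observed RVs that have a minibatched variable among their ancestors would get
# a total_size attached.
from pytensor.graph.basic import ancestors

needs_total_size = [
    rv for rv in observed_rvs
    if set(minibatch_vars) & set(ancestors([rv]))
]
```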
The correct thing would be a dim analysis, like we do for MarginalModel, to confirm the first dim of the data maps to the first dim of the observed RVs, which is when the rewrite is valid. We may not want to do that, but we should be clear about the assumptions in the docstrings.
An example where the minibatch rewrite will fail / do the wrong thing is if you transpose the data before you use it in the observations.
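A toy example of that failure mode (not from the PR): the data's leading axis is the feature axis, and it only becomes the observation axis after a transpose, so blindly minibatching the first dim of the data would subsample features rather than observations, and total_size would be wrong.

```python
import numpy as np
import pymc as pm

with pm.Model():
    X_T = pm.Data("X_T", np.random.normal(size=(3, 1000)))  # stored as (features, observations)
    beta = pm.Normal("beta", shape=3)
    mu = X_T.T @ beta                                        # transposed before use
    pm.Normal("y", mu=mu, sigma=1.0, observed=np.random.normal(size=1000))
```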
replacements = {datum: minibatch_vars[i] for i, datum in enumerate(data_vars)}
assert 0
# Add total_size to all observed RVs
total_size = data_vars[0].get_value().shape[0]
total size can be symbolic I think?
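A sketch of that symbolic alternative for the line under review: take the length from the data variable's symbolic shape rather than eagerly from `.get_value()`, so it stays valid if the shared data is later swapped with pm.set_data.

```python
total_size = data_vars[0].shape[0]  # a symbolic scalar, not a Python int
```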
data_vars = [
    memo[datum].owner.inputs[0]
    for datum in (model.named_vars[datum_name] for datum_name in model.named_vars)
There's a model.data_vars. You should however allow users to specify which data vars are to be minibatched (default to all is fine). Alternatively we could restrict this to models with dims, and the user has to tell us which dim is being minibatched? That makes the graph analysis easier.
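One possible signature for those suggestions (hypothetical, not the PR's final API): let callers restrict the transform to particular data variables, defaulting to all of model.data_vars.

```python
from collections.abc import Sequence

from pymc import Model


def model_to_minibatch(
    model: Model,
    batch_size: int,
    data_vars: Sequence[str] | None = None,  # None means "minibatch every data variable"
) -> Model:
    ...
```

The dims-based variant would instead take the name of the dim being minibatched.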
Yep, I have reworked this code and need to push my changes!
In pymc/model/transform/basic.py:
@@ -62,6 +66,47 @@ def parse_vars(model: Model, vars: ModelVariable | Sequence[ModelVariable]) -> l
     return [model[var] if isinstance(var, str) else var for var in vars_seq]
+def model_to_minibatch(model: Model, batch_size: int) -> Model:
+    """Replace all Data containers with pm.Minibatch, and add total_size to all observed RVs."""
+    from pymc.variational.minibatch_rv import create_minibatch_rv
+
+    fgraph, memo = fgraph_from_model(model, inlined_views=True)
+
+    # obs_rvs, data_vars = model.rvs_to_values.items()
+
+    data_vars = [
+        memo[datum].owner.inputs[0]
+        for datum in (model.named_vars[datum_name] for datum_name in model.named_vars)
Description
A pain point for me when testing different algorithms (e.g. MCMC vs VI) is that I don't want to write a 2nd version of the model with pm.Minibatch on the data. This PR adds a model transformation that does the conversion for the user. It's the reverse of the remove_minibatched_nodes transformer that @zaxtax implemented recently.
This is a WIP; it doesn't actually work yet, because I can't figure out how to rebuild the observed variable with total_size set correctly. Help wanted.
Related Issue
Checklist
Type of change
📚 Documentation preview 📚: https://pymc--7785.org.readthedocs.build/en/7785/